Simple Tasks and Symmetry

using the e3nn repository

tutorial by: Tess E. Smidt (blondegeek)

There are some unintuitive consequences of using E(3) equivariant neural networks.

One example is that the symmetry of your output has to be equal to or higher than the symmetry of your input. The following four simple tasks help demonstrate this:

  • Task 1: Distort a rectangle to a square.
  • Task 2: Distort a square to a rectangle.
  • Task 3: Distort a square to a rectangle -- with symmetry breaking (using representation theory).
  • Task 4: Distort a square to a rectangle -- with symmetry breaking (using gradients to change input).

We will see that we can quickly do Task 1, but not Task 2. Only by using symmetry breaking in Task 3 and Task 4 are we able to distort a square into a rectangle.
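Before diving in, here is a toy illustration of why this happens, using plain numpy rather than e3nn. We hand-roll a simple rotation-equivariant "layer" (each point receives a radially-weighted sum of vectors to the other points, a crude analogue of the convolutions used below) and apply it to a square with identical features: because a 90-degree rotation permutes the square's points, the per-point outputs must permute identically, so all four output vectors have the same magnitude and the layer cannot push the x and y corners by different amounts.

```python
import numpy as np

# A toy rotation-equivariant "layer": each point receives the sum of vectors
# to the other points, weighted by a radial function of distance.
# (A hand-rolled analogue of the convolutions used below -- illustrative only.)
def equivariant_layer(points):
    out = np.zeros_like(points)
    for i in range(len(points)):
        for j in range(len(points)):
            if i == j:
                continue
            r = points[j] - points[i]
            d = np.linalg.norm(r)
            out[i] += np.exp(-d) * r  # radial weight * relative vector
    return out

square = np.array([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]])
square -= square.mean(axis=0)

out = equivariant_layer(square)

# All per-point outputs have equal magnitude: the layer cannot distinguish
# the x direction from the y direction, as square -> rectangle would require.
norms = np.linalg.norm(out, axis=1)
print(np.allclose(norms, norms[0]))  # True
```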

In [1]:
import torch
from functools import partial
import numpy as np

import e3nn
import e3nn.o3 as o3
from e3nn.point.operations import Convolution
from e3nn.non_linearities import GatedBlock
from e3nn.kernel import Kernel
from e3nn.kernel_mod import Kernel as KernelMod
from e3nn.radial import CosineBasisModel
from e3nn.non_linearities import rescaled_act

import matplotlib.pyplot as plt
%matplotlib inline

from e3nn.spherical_tensor import SphericalTensor

torch.set_default_dtype(torch.float64)
In [2]:
# Define our geometry
square = torch.tensor(
    [[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]]
)
square -= square.mean(-2)
sx, sy = 0.5, 1.5
rectangle = square * torch.tensor([sx, sy, 0.])
rectangle -= rectangle.mean(-2)

N, _ = square.shape

markersize = 15

def plot_task(ax, start, finish, title, marker=None):
    # Repeat the first point at the end to close the outline.
    ax.plot(torch.cat([start[:, 0], start[:1, 0]]), 
            torch.cat([start[:, 1], start[:1, 1]]), 'o-', 
            markersize=markersize + 5 if marker else markersize, 
            marker=marker if marker else 'o')
    ax.plot(torch.cat([finish[:, 0], finish[:1, 0]]), 
            torch.cat([finish[:, 1], finish[:1, 1]]), 'o-', markersize=markersize)
    for i in range(N):
        ax.arrow(start[i, 0], start[i, 1], 
                 finish[i, 0] - start[i, 0], 
                 finish[i, 1] - start[i, 1],
                 length_includes_head=True, head_width=0.05, facecolor="black", zorder=100)

    ax.set_title(title)
    ax.set_axis_off()

# fig, axes = plt.subplots(1, 3, figsize=(14, 6))
fig, axes = plt.subplots(1, 2, figsize=(9, 6))
plot_task(axes[0], rectangle, square, "Task 1: Rectangle to Square")
plot_task(axes[1], square, rectangle, "Task 2: Square to Rectangle")
# plot_task(axes[2], square, rectangle, "Task 3: Square to Rectangle with Symmetry Breaking", "$\u2B2E$")

In these tasks, we want to move 4 points from one configuration to another. The input to the network will be the initial geometry and features on that geometry. The output will be used to signify the "displacement" of each point to the new configuration. We can represent a displacement in a couple of different ways. The simplest is as an L=1 vector, Rs = [(1, 1)]. However, to better illustrate the symmetry properties of the network, we are instead going to use a spherical harmonic signal -- more specifically, the peak of the spherical harmonic signal -- to signify the displacement of the original point.
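Why does the peak encode the displacement? The L=1 part of a spherical harmonic signal built from a single vector $v$ is, up to normalization, $f(n) = v \cdot n$ over unit directions $n$, which is maximized exactly at $n = v/|v|$. A convention-agnostic numpy sketch (not e3nn code) of recovering a vector from the peak of such a signal:

```python
import numpy as np

# The L=1 signal of a vector v is f(n) = v . n on the unit sphere,
# so the signal peaks along v -- the peak direction IS the displacement.
v = np.array([0.3, -0.2, 0.1])

# Sample many unit directions and find where f(n) = v . n is largest.
rng = np.random.default_rng(0)
n = rng.normal(size=(200000, 3))
n /= np.linalg.norm(n, axis=1, keepdims=True)
peak_dir = n[np.argmax(n @ v)]

# The peak direction matches v / |v| up to sampling resolution.
print(np.allclose(peak_dir, v / np.linalg.norm(v), atol=0.05))  # True
```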

First, we set up a very basic network that has the same representation list Rs = [(1, L) for L in range(5 + 1)] throughout the entire network. The input will be a spherical tensor with representation Rs and the output will also be a spherical tensor with representation Rs. We will interpret the output of the network as a spherical harmonic signal where the peak location will signify the desired displacement.

For these examples, we will use the default e3nn.networks.GatedConvNetwork class for our model.

In [3]:
from e3nn.networks import GatedConvNetwork
L_max = 5
Rs = [(1, L) for L in range(L_max + 1)]
Network = partial(GatedConvNetwork, Rs_in=Rs, Rs_hidden=Rs, Rs_out=Rs, lmax=L_max, max_radius=3.0, kernel=KernelMod)

Task 1: Distort a rectangle to a square.

In this task, our input is four points in the shape of a rectangle with simple scalars (1.0) at each point. The task is to learn to displace the points to form a (more symmetric) square.

In [4]:
model = Network()

params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-2)
loss_fn = torch.nn.MSELoss()
In [5]:
input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel

displacements = square - rectangle
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal for i in range(N)])
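As a quick sanity check on the feature dimension used for `input` above: one copy of each L from 0 to L_max has $\sum_L (2L+1) = (L_{max}+1)^2$ components, so with L_max = 5 each point carries 36 features, and index 0 is the lone L=0 scalar we set to 1.

```python
# Dimension bookkeeping for Rs = [(1, L) for L in range(L_max + 1)]:
# sum of (2L + 1) over L = 0..L_max equals (L_max + 1)^2.
L_max = 5
dim = sum(2 * L + 1 for L in range(L_max + 1))
print(dim, dim == (L_max + 1) ** 2)  # 36 True
```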
In [6]:
iterations = 100
for i in range(iterations):
    output = model(input, rectangle.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if i % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor(0.0030, grad_fn=<MseLossBackward>)
tensor(0.0002, grad_fn=<MseLossBackward>)
tensor(4.5192e-05, grad_fn=<MseLossBackward>)
tensor(1.4629e-05, grad_fn=<MseLossBackward>)
tensor(4.8856e-06, grad_fn=<MseLossBackward>)
tensor(1.4875e-06, grad_fn=<MseLossBackward>)
tensor(4.8386e-07, grad_fn=<MseLossBackward>)
tensor(2.1258e-07, grad_fn=<MseLossBackward>)
tensor(5.4593e-07, grad_fn=<MseLossBackward>)
tensor(4.0142e-07, grad_fn=<MseLossBackward>)
In [7]:
# Plot spherical harmonic projections
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
In [8]:
def plot_output(start, finish, output, start_label, finish_label):
    rows, cols = 1, 1
    specs = [[{'is_3d': True} for i in range(cols)]
             for j in range(rows)]
    fig = make_subplots(rows=rows, cols=cols, specs=specs)
    fig.add_trace(go.Scatter3d(x=start[:, 0], y=start[:, 1], z=start[:, 2], mode="markers", name=start_label))
    fig.add_trace(go.Scatter3d(x=finish[:, 0], y=finish[:, 1], z=finish[:, 2], mode="markers", name=finish_label))
    for i in range(N):
        r, f = SphericalTensor(output[0][i].detach(), 1, L_max).plot(center=start[i])
        trace = go.Surface(x=r[..., 0], y=r[..., 1], z=r[..., 2], surfacecolor=f.numpy())
        trace.showscale = False
        fig.add_trace(trace, 1, 1)
    return fig
In [9]:
output = model(input, rectangle.unsqueeze(0))
fig = plot_output(rectangle, square, output, "Rectangle", "Square")
fig.update_layout(scene_aspectmode='data')
fig.show()

And let's check that it's equivariant

In [10]:
angles = torch.rand(3) * torch.tensor([np.pi, 2 * np.pi, np.pi])
rot = o3.rot(*angles)
rot_rectangle = torch.einsum('xy,jy->jx', (rot, rectangle))
rot_square = torch.einsum('xy,jy->jx', (rot, square))
output = model(input, rot_rectangle.unsqueeze(0))
fig = plot_output(rot_rectangle, rot_square, output, "Rectangle", "Square")
fig.update_layout(scene_aspectmode='data')
fig.show()
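A note on the rotation applied above: `torch.einsum('xy,jy->jx', rot, points)` is just matrix multiplication by the rotation matrix, i.e. `points @ rot.T`. A small numpy sketch (not the e3nn `o3.rot` helper) with a rotation about $\hat{z}$:

```python
import numpy as np

# einsum('xy,jy->jx', rot, points) applies rot to each row of points,
# which is the same as points @ rot.T.
theta = 0.7
rot = np.array([[np.cos(theta), -np.sin(theta), 0.],
                [np.sin(theta),  np.cos(theta), 0.],
                [0.,             0.,            1.]])
points = np.array([[1., 0., 0.], [0., 1., 0.]])

rotated = np.einsum('xy,jy->jx', rot, points)
print(np.allclose(rotated, points @ rot.T))  # True
print(np.allclose(rot @ rot.T, np.eye(3)))   # rotations are orthogonal: True
```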

Task 2: Now the reverse! Distort a square to a rectangle.

In this task, our input is four points in the shape of a square with simple scalars (1.0) at each point. The task is to learn to displace the points to form a (less symmetric) rectangle. Can the network learn this task?

In [11]:
model = Network()

params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-2)
loss_fn = torch.nn.MSELoss()
In [12]:
input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel

displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal for i in range(N)])
In [13]:
iterations = 100
for i in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if i % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor(0.0036, grad_fn=<MseLossBackward>)
tensor(0.0008, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)

Hmm... seems to get stuck. Let's try more iterations.

In [14]:
iterations = 100
for i in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if i % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)

It's stuck. What's going on?

In [15]:
fig = plot_output(square, rectangle, output, "Square", "Rectangle")
fig.update_layout(scene_aspectmode='data')
fig.show()

The symmetry of the output must be higher or equal to the symmetry of the input!

To be able to do this task, you need to give the network more information -- information that breaks the symmetry to that of the desired output. The square has a point group of $D_{4h}$ (16 elements) while the rectangle has a point group of $D_{2h}$ (8 elements).
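We can make the symmetry mismatch concrete by counting, in plain numpy, the in-plane rotations (multiples of 90 degrees about $\hat{z}$) that map each shape onto itself -- a rough proxy for the rotational part of the point groups quoted above. The square survives all 4; the rectangle only 0 and 180 degrees.

```python
import numpy as np

def preserved_rotations(points):
    """Count rotations by k*90 degrees about z that permute the point set."""
    count = 0
    for k in range(4):
        theta = k * np.pi / 2
        rot = np.array([[np.cos(theta), -np.sin(theta), 0.],
                        [np.sin(theta),  np.cos(theta), 0.],
                        [0., 0., 1.]])
        rotated = points @ rot.T
        # Preserved if every original point appears in the rotated set.
        match = all(np.isclose(rotated, p, atol=1e-8).all(axis=1).any()
                    for p in points)
        count += match
    return count

square = np.array([[-.5, -.5, 0.], [.5, -.5, 0.], [.5, .5, 0.], [-.5, .5, 0.]])
rectangle = square * np.array([0.5, 1.5, 0.])  # same sx, sy as the notebook

print(preserved_rotations(square), preserved_rotations(rectangle))  # 4 2
```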

A technical note (for those who are interested).

In this example, we are NOT using a network equivariant to parity -- that will be in another update / tutorial -- so we are actually only sensitive to the fact that the square has $C_4$ symmetry while the rectangle has $C_2$ symmetry.

Task 3: Fixing Task 2. Distort a square into a rectangle -- now, with symmetry breaking (using representation theory)!

In this task, our input is four points in the shape of a square with simple scalars (1.0) AND a contribution for the $x^2 - y^2$ feature at each point. The task is to learn to displace the points to form a (less symmetric) rectangle. Can the network learn this task?

In [16]:
model = Network()

params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-2)
loss_fn = torch.nn.MSELoss()
In [17]:
input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel
# Breaking x and y symmetry with x^2 - y^2 component
input[:, :, 8] = 0.1  # x^2 - y^2

displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal for i in range(N)])
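Where does index 8 come from? With one copy of each L stored in order L = 0, 1, 2, ... and the 2L+1 components of each L ordered m = -L, ..., +L, the block for L starts at offset $L^2$ and component m sits at $L^2 + (m + L)$. The helper below is hypothetical (not an e3nn function), but it is consistent with the notebook's use of index 8 for $x^2 - y^2$, the real m = +2 component of L = 2:

```python
# Hypothetical index helper for a spherical tensor with one copy of each L,
# components of each L ordered m = -L, ..., +L.
def flat_index(L, m):
    return L * L + (m + L)

print(flat_index(2, 2))            # 8  (x^2 - y^2: L=2, m=+2)
print([L * L for L in range(6)])   # block offsets: [0, 1, 4, 9, 16, 25]
```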
In [18]:
iterations = 100
for i in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if i % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor(0.0023, grad_fn=<MseLossBackward>)
tensor(0.0005, grad_fn=<MseLossBackward>)
tensor(8.1627e-05, grad_fn=<MseLossBackward>)
tensor(1.9512e-05, grad_fn=<MseLossBackward>)
tensor(1.0131e-05, grad_fn=<MseLossBackward>)
tensor(2.9695e-06, grad_fn=<MseLossBackward>)
tensor(8.9779e-07, grad_fn=<MseLossBackward>)
tensor(3.5140e-07, grad_fn=<MseLossBackward>)
tensor(1.5716e-07, grad_fn=<MseLossBackward>)
tensor(4.7280e-08, grad_fn=<MseLossBackward>)
In [19]:
fig = plot_output(square, rectangle, output, "Square", "Rectangle")
fig.update_layout(scene_aspectmode='data')
fig.show()

What is the $x^2 - y^2$ term doing? It's breaking the symmetry along the $\hat{x}$ and $\hat{y}$ directions.

Notice how the shape below is an ellipsoid elongated in the $\hat{y}$ direction and squished in the $\hat{x}$ direction. This isn't the only perturbation we could have added, but it is the most symmetric.

In [21]:
rows, cols = 1, 1
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

L_max = 5
Rs = [(1, L) for L in range(L_max + 1)]
sum_Ls = sum(2 * L + 1 for mult, L in Rs) 

# Spherical tensor up to L_max: a scalar plus an x^2 - y^2 perturbation
signal = torch.zeros(sum_Ls)
signal[0] = 1
# Breaking x and y symmetry with x^2 - y^2
signal[8] = -0.1

sphten = SphericalTensor(signal, 1, L_max)

r, f = sphten.plot(relu=False, n=60)
trace = go.Surface(x=r[..., 0], y=r[..., 1], z=r[..., 2], surfacecolor=f.numpy())
fig.add_trace(trace, row=1, col=1)
fig.show()

Sure, but where did the $x^2 - y^2$ come from?

It's a bit of a complicated story, but at the surface level here it is: character tables are handy tabulations of how certain spherical tensor data types transform under the operations of a given symmetry group. The rows are irreducible representations (irreps for short) and the columns are classes of similar group elements (called conjugacy classes). Character tables are most commonly seen for finite subgroups of $E(3)$, as they are used extensively in solid state physics, crystallography, chemistry, etc. Next to the part of the table with the "characters", there are often columns listing linear, quadratic, and cubic functions (meaning they are of order 1, 2, and 3) that transform in the same way as a given irrep.

So, a square has a point group symmetry of $D_{4h}$ while a rectangle has a point group symmetry of $D_{2h}$.

If we look at column headers of character tables for $D_{4h}$ and $D_{2h}$...

... we can see that the irrep $B_{1g}$ of $D_{4h}$ has $-1$'s in the columns for all the symmetry operations that $D_{2h}$ DOESN'T have. If we look along that row to the "quadratic functions" column we see, voila, $x^2 - y^2$. So, to break all the symmetries that $D_{4h}$ has and $D_{2h}$ DOESN'T have, we add a non-zero contribution to the $x^2 - y^2$ component of our spherical harmonic tensors.

WARNING: Character tables are written down with specific coordinate system conventions. For example, the $\hat{z}$ axis always points along the highest symmetry axis, $\hat{y}$ along the next highest, etc. We have specifically set up our problem to have a coordinate frame that matches these conventions.

A technical note (for those who are interested).

Again, in this example (because we are choosing to leave out parity), we are only sensitive to the fact that the square has $C_4$ symmetry while the rectangle has $C_2$ symmetry. However, you can check the character tables for the point groups $C_4$ and $C_2$ to see that the argument above still holds for the $x^2 - y^2$ order parameter.
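We can check this numerically: rotating the coordinates by $\theta$ about $\hat{z}$ sends $x^2 - y^2$ to $\cos(2\theta)(x^2 - y^2) + \sin(2\theta)(2xy)$, so a 90-degree rotation flips its sign (it breaks $C_4$) while a 180-degree rotation leaves it alone (it respects $C_2$). A small numpy sketch:

```python
import numpy as np

def f(p):
    """The order parameter x^2 - y^2 evaluated at a point."""
    return p[0] ** 2 - p[1] ** 2

def rotate_z(p, theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([c * p[0] - s * p[1], s * p[0] + c * p[1], p[2]])

p = np.array([0.3, 0.7, 0.0])
print(np.isclose(f(rotate_z(p, np.pi / 2)), -f(p)))  # 90 deg: sign flip -> True
print(np.isclose(f(rotate_z(p, np.pi)), f(p)))       # 180 deg: invariant -> True
```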

Task 4: Fixing Task 2 without having to read character tables like in Task 3. Distort a square into a rectangle -- now, with symmetry breaking (using gradients to change the input)!

In this task, our input is four points in the shape of a square with simple scalars (1.0) AND then we LEARN how to change the inputs to break symmetry such that we can fit a better model.

In [22]:
model = Network()

params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-2)
loss_fn = torch.nn.MSELoss()

input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel

displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal for i in range(N)])
input.requires_grad = True

input_optimizer = torch.optim.Adam([input], 1e-3)
input_loss_fn = torch.nn.MSELoss()
In [23]:
displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i].unsqueeze(0), L_max).signal for i in range(N)])
projections = projections.unsqueeze(0)

First, we'll train the model until it gets stuck.

In [24]:
iterations = 201
for i in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections)
    if i % 30 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor(0.0044, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)

This gets stuck like before. So let's try alternating between updating our input and updating the model.
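The alternating-update pattern used below can be sketched on a toy problem: minimize $(wx - 1)^2$ over both a "model parameter" $w$ and a learnable "input" $x$, taking a gradient step on each in turn. This is only a schematic analogy in plain Python (no torch), not the notebook's actual optimization.

```python
def loss(w, x):
    return (w * x - 1.0) ** 2

w, x = 0.1, 0.1  # small nonzero init so gradients are nonzero
lr = 0.05
initial = loss(w, x)
for _ in range(500):
    # "model" step: d(loss)/dw = 2 * (w*x - 1) * x
    w -= lr * 2.0 * (w * x - 1.0) * x
    # "input" step: d(loss)/dx = 2 * (w*x - 1) * w
    x -= lr * 2.0 * (w * x - 1.0) * w

print(loss(w, x) < 1e-6, loss(w, x) < initial)  # True True
```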

In [25]:
iterations = 101
eps = 1e-6
for i in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections)
    if i % 10 == 0:
        print('model loss: ', loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    output = model(input, square.unsqueeze(0))
    # This is the regular loss for the model
    loss = input_loss_fn(output, projections)
    # This is the loss for keeping the changes to the input small
    loss += ((input[:, :, 0:1] - torch.ones_like(input[:, :, 0:1])).abs()).mean()
    loss += ((input[:, :, 9:]).abs()).mean()
    loss += ((input[:, :, 1:4]).abs()).mean()
    # Prefer features on atoms to be the same (global parameter)
    loss += ((input[:, :, 4:9] - input[:, 0, 4:9])**2).mean()
    # and add a mild L1 penalty on the L=2 input features.
    loss += 1e-3 * ((input[:, :, 4:9]).abs()).mean()
    
    if i % 20 == 0:
        print('input loss: ', loss)
    input_optimizer.zero_grad()
    loss.backward()
    input_optimizer.step()
model loss:  tensor(0.0007, grad_fn=<MseLossBackward>)
input loss:  tensor(0.0007, grad_fn=<AddBackward0>)
model loss:  tensor(0.0007, grad_fn=<MseLossBackward>)
model loss:  tensor(0.0006, grad_fn=<MseLossBackward>)
input loss:  tensor(0.0012, grad_fn=<AddBackward0>)
model loss:  tensor(0.0001, grad_fn=<MseLossBackward>)
model loss:  tensor(0.0001, grad_fn=<MseLossBackward>)
input loss:  tensor(0.0006, grad_fn=<AddBackward0>)
model loss:  tensor(2.5280e-05, grad_fn=<MseLossBackward>)
model loss:  tensor(6.5288e-06, grad_fn=<MseLossBackward>)
input loss:  tensor(0.0004, grad_fn=<AddBackward0>)
model loss:  tensor(3.3903e-06, grad_fn=<MseLossBackward>)
model loss:  tensor(8.6898e-07, grad_fn=<MseLossBackward>)
input loss:  tensor(0.0004, grad_fn=<AddBackward0>)
model loss:  tensor(3.3313e-07, grad_fn=<MseLossBackward>)
model loss:  tensor(1.8987e-07, grad_fn=<MseLossBackward>)
input loss:  tensor(0.0004, grad_fn=<AddBackward0>)

If we examine the input, we should see that the only components that are (largely) non-zero are the scalar features (which are all 1's) and the L=2 feature corresponding to $x^2 - y^2$, which is the 5th element of the L=2 array.

In [26]:
round_decimal = 3
print("L=0 ")
print(input.detach().numpy().round(round_decimal)[:, :, 0])
print("L=1")
print(input.detach().numpy().round(round_decimal)[:, :, 1: 1 + 3])
print("L=2")
print(input.detach().numpy().round(round_decimal)[:, :, 4: 4 + 5])
print("L=3")
print(input.detach().numpy().round(round_decimal)[:, :, 9: 9 + 7])
print("L=4")
print(input.detach().numpy().round(round_decimal)[:, :, 16: 16 + 9])
print("L=5")
print(input.detach().numpy().round(round_decimal)[:, :, 25: 25 + 11])
L=0 
[[1. 1. 1. 1.]]
L=1
[[[-0.  0.  0.]
  [-0. -0. -0.]
  [ 0.  0. -0.]
  [ 0. -0.  0.]]]
L=2
[[[ 0.    -0.    -0.    -0.    -0.025]
  [ 0.    -0.     0.     0.    -0.025]
  [-0.     0.     0.     0.    -0.025]
  [ 0.    -0.    -0.    -0.    -0.025]]]
L=3
[[[-0.  0.  0.  0.  0.  0. -0.]
  [-0.  0. -0. -0. -0. -0.  0.]
  [ 0.  0.  0.  0. -0.  0.  0.]
  [ 0.  0. -0. -0. -0. -0. -0.]]]
L=4
[[[ 0.  0. -0. -0.  0. -0.  0.  0.  0.]
  [-0. -0.  0. -0.  0.  0.  0.  0.  0.]
  [-0. -0. -0. -0.  0. -0.  0. -0.  0.]
  [-0.  0.  0.  0.  0.  0.  0. -0.  0.]]]
L=5
[[[ 0.  0. -0.  0. -0. -0.  0.  0. -0.  0.  0.]
  [-0.  0.  0.  0. -0.  0. -0. -0.  0. -0. -0.]
  [ 0.  0. -0.  0.  0. -0. -0. -0. -0.  0. -0.]
  [ 0.  0. -0.  0.  0.  0.  0.  0. -0. -0.  0.]]]

This plot shows what the new input looks like. It's similar to the above plot from Task 3.

In [27]:
fig = plot_output(square, square, input, '', '')
fig.update_layout(scene_aspectmode='data')
fig.show()